Refactor Stateful operator and custom op #6928
op.set_attr<FExecType>("FExecType", OpExecType);
op.set_attr<FCreateOpState>("FCreateOpState", OpPropCreateLayerOp);
op.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", LegacyOpForward);
op.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", LegacyOpForward);
if (reg->key_var_num_args.length() != 0) {
How do we set the attr for FStatefulComputeEx<xpu> for legacy ops? Do we have to modify the interface of the operator class to have Forward/Backward with NDArray inputs and outputs?
Normally you use FStatefulCompute
So legacy ops won't support sparse NDArrays, and they'll all fall back to dense before forward/backward?
Yes. You need to move to the new interface to add sparse support.
ok
include/mxnet/op_attr_types.h
Outdated
 *
 * \note Register under "FStatefulCompute<cpu>" and "FStatefulCompute<gpu>"
 */
using FStatefulCompute = std::function<void (const dmlc::any& state,
A const reference gives the user the feeling that state won't be mutated. Maybe use a pointer type?
For mutable state you need to use shared_ptr.
I know, but does this API imply that state can be mutated, or not?
The state itself cannot be mutated. If it's a pointer, the content it points to can be mutated.
The thing is, state can be copied around, so you cannot modify it in place.
Another option is to use shared_ptr<any> state, but that forces you to use a pointer, and you cannot use make_shared to construct the state in place.
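To make the distinction concrete, here is a minimal sketch of the pattern described above. It uses `std::any` (C++17) as a stand-in for `dmlc::any`, which is an assumption about how closely the two behave. `CounterState`, `CreateState`, `Forward`, and `CallsAfterCopy` are hypothetical names for illustration and are not part of MXNet.

```cpp
#include <any>
#include <cassert>
#include <memory>

// Hypothetical operator state; not an MXNet type.
struct CounterState {
  int forward_calls = 0;
};

// Hypothetical analogue of a create-state function: the any holds a
// shared_ptr, so make_shared constructs the real state object directly.
std::any CreateState() {
  return std::make_shared<CounterState>();
}

// Stand-in for a stateful compute function. The any handle is const,
// so the handle itself cannot be replaced from inside this function.
void Forward(const std::any& state) {
  // any_cast on a const std::any* yields a pointer to the const contents.
  const auto* handle = std::any_cast<std::shared_ptr<CounterState>>(&state);
  assert(handle != nullptr && *handle != nullptr);
  // The shared_ptr is const but its pointee is not, so this mutation is legal.
  (*handle)->forward_calls += 1;
}

// Copying the any copies the shared_ptr, not the CounterState, so both
// handles observe the same mutable object.
int CallsAfterCopy() {
  std::any state = CreateState();
  std::any copy = state;
  Forward(state);
  Forward(copy);
  return std::any_cast<std::shared_ptr<CounterState>>(state)->forward_calls;
}
```

In this sketch `CallsAfterCopy()` returns 2: both calls go through different copies of the handle but increment the same `CounterState`. If the any held a `CounterState` by value instead, each copy would carry its own independent state, which is why in-place modification of the handle is not meaningful.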
This should at least be heavily documented in the comment.
9766815 to b0cf3ca
* refactor create layer * fix * refactor custom op * fix * fix * fix * fix * fix OpState * remove superfluous infershape * fix * fix * fix lint * fix * fix * fix * Update CMakeLists.txt * delete * fix * fix scala
@tqchen @eric-haibin-lin